# 模型参数配置文件

import os

# 模型相关参数
TIME_STEPS = 128
CHANNELS = 6
INPUT_SHAPE = (TIME_STEPS, CHANNELS)


# 数据路径参数
TRAIN_DATA_PATH = os.path.join(os.getcwd(), 'train_data')
TRAIN_LABELS_PATH = None

# 动作标签映射
MOTION_NAMES = ['flip','RightAngle', 'SharpAngle', 'Lightning', 'Triangle', 'Letter_h', 'letter_R', 'letter_W', 'letter_phi', 'Circle', 'UpAndDown', 'Horn', 'Wave', 'NoMotion']
NUM_CLASSES = len(MOTION_NAMES)
MOTION_TO_LABEL = {name: i for i, name in enumerate(MOTION_NAMES)}

# 数据集参数
DEF_FILE_FORMAT = ".txt"
DEF_FILE_MAX = 999
DEF_USE_COLS = (0, 1, 2, 3, 4, 5)  # 六轴数据
DEF_N_ROWS = None  # 读取所有行

# 训练参数
EPOCHS = 1000
BATCH_SIZE = 32

# 模型保存参数
DEF_MODEL_NAME = 'model.h5'
DEF_MODEL_H_NAME = 'weights.h'

# 预训练模型路径（如果存在，则在此基础上继续训练）
# PRETRAINED_MODEL_PATH =  os.path.join(os.getcwd(), 'model.h5')  # 如果设置有效路径，则加载该模型继续训练
PRETRAINED_MODEL_PATH = None